import numpy as np
import argparse
import tqdm, copy

class Coordinator_baseline:
    """Distributed exponential-weights baseline with doubly-sampled loss estimates.

    Each round the coordinator draws one expert from the exponential-weights
    distribution. It then builds an importance-weighted estimate of every
    expert's cost summed over servers. To do so it independently keeps each
    expert with probability ``bn / expert_num`` and each server with
    probability ``bs / server_num``.
    """

    def __init__(self, args):
        self.args = args
        self.weights = np.zeros(args.expert_num)
        # Expected number of sampled experts / servers per round. Both are
        # clamped so the Bernoulli inclusion probabilities used in ``dewa``
        # never exceed 1; np.random.binomial raises ValueError for p > 1.
        # This happens e.g. on a data set with a single server or with fewer
        # experts than --be.
        self.bn = min(args.be, args.expert_num)
        self.bs = min(2, args.server_num)
        self.communication = 0
        self.total_cost = 0

    def dewa(self, server_costs):
        """Play one round and update the cumulative loss estimates.

        Args:
            server_costs: 2-D array indexed ``[server, expert]``. Rows are read
                per server and columns per expert.

        Returns:
            The chosen expert as a length-1 index array (from
            ``np.random.choice``).
        """
        # Shift by the minimum before exponentiating to avoid underflow; the
        # shift cancels in the normalisation.
        prob = np.exp(-self.args.lr * (self.weights - np.min(self.weights)))
        prob /= prob.sum()
        expert_index = np.random.choice(self.args.expert_num, 1, p=prob)
        self.total_cost += server_costs[:, expert_index].sum()

        # Independent Bernoulli masks over experts (zi) and servers (zj).
        zi = np.random.binomial(1, self.bn / self.args.expert_num, self.args.expert_num)
        zj = np.random.binomial(1, self.bs / self.args.server_num, self.args.server_num)
        temp_cost = np.zeros(self.args.expert_num)
        for j in range(self.args.server_num):
            temp_cost += zj[j] * server_costs[j]
        l = temp_cost * zi
        # Divide by both inclusion probabilities so the estimate is unbiased
        # for the per-expert cost summed over all servers.
        l = self.args.expert_num / self.bn / self.bs * self.args.server_num * l
        self.weights += l
        # Count sampled experts and sampled servers toward communication.
        self.communication += (zi.sum() + zj.sum())
        # The 'dewa-p-baseline' meta-coordinator adds this per-round term itself.
        if not self.args.case == 'dewa-p-baseline':
            self.communication += self.args.server_num
        return expert_index

class Coordinator_DEWAP_baseline:
    """Meta-learner over K parallel ``Coordinator_baseline`` instances.

    Every round each sub-coordinator plays and is charged the summed cost of
    its own pick. An exponential-weights draw over those cumulative charges
    decides whose pick is reported as the meta-learner's choice.
    """

    def __init__(self, args):
        self.args = args
        self.communication = 0
        self.total_cost = 0
        # ~2 log T sub-learners in simulation; a fixed 2 on real data.
        self.K = 2 if args.real else np.ceil(np.log(args.T) * 2).astype(int)
        self.coordinators = [Coordinator_baseline(args) for _ in range(self.K)]
        self.weights = np.zeros(self.K)

    def dewa_p(self, server_costs):
        """Run one round of every sub-coordinator and return the selected pick."""
        shifted = self.weights - self.weights.min()
        prob = np.exp(-self.args.lr * shifted)
        prob /= prob.sum()
        coord_index = np.random.choice(self.K, 1, p=prob)

        n_servers = self.args.server_num
        for i, sub in enumerate(self.coordinators):
            # Only this round's traffic of the sub-learner is folded in below.
            sub.communication = 0
            candidate = sub.dewa(server_costs)
            self.weights[i] += server_costs[:, candidate].sum()
            self.communication += n_servers * 2

            if coord_index == i:
                expert_index = candidate
                self.total_cost += server_costs[:, candidate].sum()
            self.communication += sub.communication
        self.communication += n_servers
        return expert_index

class Coordinator:
    """DEWA coordinator fed by servers' one-bit sampled cost reports.

    ``dewa`` receives, per server, a 0/1 float vector marking the experts that
    server reported this round. The counts are rescaled by
    ``expert_num / bn`` before being added to the loss estimates.
    """

    def __init__(self, args):
        self.args = args
        self.weights = np.zeros(args.expert_num)
        self.bn = args.be
        self.communication = 0
        self.total_cost = 0

    def dewa(self, sampled_experts, server_costs):
        """Pick an expert, then fold the servers' sampled reports into the weights.

        Args:
            sampled_experts: sequence indexed by server; entry ``j`` is a
                per-expert vector of 0/1 report bits.
            server_costs: 2-D array indexed ``[server, expert]``.

        Returns:
            The chosen expert as a length-1 index array.
        """
        n_experts = self.args.expert_num
        logits = -self.args.lr * (self.weights - self.weights.min())
        prob = np.exp(logits)
        prob /= prob.sum()
        expert_index = np.random.choice(n_experts, 1, p=prob)
        self.total_cost += server_costs[:, expert_index].sum()

        hits = np.zeros(n_experts)
        for j in range(self.args.server_num):
            report = sampled_experts[j]
            hits += report
            # Each reported bit counts as one unit of communication.
            self.communication += report.sum()
        self.weights += n_experts / self.bn * hits
        if self.args.case != 'dewa-p':
            self.communication += self.args.server_num
        return expert_index

class Coordinator_DEWAP:
    """Meta-learner over K parallel ``Coordinator`` (DEWA) instances.

    Each sub-coordinator gets its own fresh alpha/beta draws from every server
    (via ``server.refresh()``) and is charged the summed cost of its pick. An
    exponential-weights draw over those charges selects the reported pick.
    """

    def __init__(self, args):
        self.args = args
        self.communication = 0
        self.total_cost = 0
        # ~2 log T sub-learners in simulation; a fixed 2 on real data.
        self.K = 2 if args.real else np.ceil(np.log(args.T) * 2).astype(int)
        self.coordinators = [Coordinator(args) for _ in range(self.K)]
        self.weights = np.zeros(self.K)

    def dewa_p(self, server_costs, servers):
        """Run one round of every sub-coordinator and return the selected pick.

        ``servers`` must expose ``refresh()``, ``alpha`` and ``beta``; each
        server is refreshed once per sub-coordinator.
        """
        shifted = self.weights - self.weights.min()
        prob = np.exp(-self.args.lr * shifted)
        prob /= prob.sum()
        coord_index = np.random.choice(self.K, 1, p=prob)

        n_servers = self.args.server_num
        for i, sub in enumerate(self.coordinators):
            sub.communication = 0
            reports = []
            for server in servers:
                server.refresh()
                reports.append(((server.alpha == 1) & (server.beta == 1)).astype(float))
            candidate = sub.dewa(reports, server_costs)
            self.weights[i] += server_costs[:, candidate].sum()
            self.communication += n_servers * 2

            if coord_index == i:
                expert_index = candidate
                self.total_cost += server_costs[:, candidate].sum()
            self.communication += sub.communication
        self.communication += n_servers
        return expert_index

class Coordinator_EWA:
    """Full-information exponential-weights baseline.

    Every server uploads its whole per-expert cost vector each round, and the
    weights accumulate the per-expert cost summed over servers.
    """

    def __init__(self, args):
        self.args = args
        self.weights = np.zeros(args.expert_num)
        self.communication = 0
        self.total_cost = 0

    def ewa(self, server_costs):
        """Pick an expert, then add this round's summed costs to the weights."""
        n_experts = self.args.expert_num
        n_servers = self.args.server_num
        gap = self.weights - np.min(self.weights)
        prob = np.exp(-self.args.lr * gap)
        prob /= prob.sum()
        expert_index = np.random.choice(n_experts, 1, p=prob)
        self.total_cost += server_costs[:, expert_index].sum()

        round_loss = np.zeros(n_experts)
        for j in range(n_servers):
            round_loss += server_costs[j]
        # expert_num values from each server, plus one expert_num-sized term.
        self.communication += (n_servers + 1) * n_experts
        self.weights += round_loss
        return expert_index

class Coordinator_EXP3:
    """EXP3 bandit baseline on the per-round cost summed over servers.

    Only the chosen expert's weight is updated, using the importance-weighted
    reward ``(1 - cost) / prob``. ``gamma`` (taken from ``--lr``) is the
    uniform exploration rate.
    """

    def __init__(self, args):
        self.args = args
        self.weights = np.ones(args.expert_num)
        self.communication = 0
        self.total_cost = 0
        self.gamma = args.lr

    def exp3(self, server_costs):
        """Play one EXP3 round and return the chosen expert (length-1 array)."""
        prob = (1 - self.gamma) * self.weights / self.weights.sum() + self.gamma / self.args.expert_num
        new_prob = prob / prob.sum()
        expert_index = np.random.choice(self.args.expert_num, 1, p=new_prob)
        self.total_cost += server_costs[:, expert_index].sum()

        x_j = (1 - server_costs[:, expert_index].sum()) / prob[expert_index]
        self.weights[expert_index] = self.weights[expert_index] * np.exp(self.gamma * x_j / self.args.expert_num)
        # The sampling distribution depends only on weights / weights.sum(),
        # so rescaling by the max leaves it unchanged. It also stops the
        # ever-growing multiplicative weights from overflowing to inf on long
        # horizons, which would make prob NaN and np.random.choice raise.
        self.weights /= self.weights.max()
        self.communication += self.args.server_num * 2
        return expert_index

class Coordinator_DEWA_M:
    """Coordinator for DEWA-M, where an expert's round cost is its max over servers.

    A round consists of ``choose`` (pick an expert), ``sample`` (draw the
    expert subset ``Be``), server reports written into ``global_li``, then
    ``receive`` (account the traffic) and ``update`` (fold ``global_li`` into
    the weights).
    """

    def __init__(self, args):
        self.args = args
        self.communication = 0
        self.total_cost = 0
        self.weights = np.zeros(args.expert_num)
        self.be = args.be
        # Per-expert running maximum reported this round; reset by update().
        self.global_li = np.zeros(args.expert_num)
        self.Be = None

    def sample(self):
        """Draw ``be`` expert ids for this round into ``self.Be``.

        NOTE(review): np.random.choice samples with replacement by default.
        ``Be`` can therefore repeat ids and cover fewer than ``be`` distinct
        experts, while ``update`` rescales by ``expert_num / be`` as if every
        expert were included with probability ``be / expert_num``. Confirm
        this is intended.
        """
        self.Be = np.random.choice(self.args.expert_num, self.be)

    def receive(self, to_send):
        """Count every non-zero reported value as one unit of communication."""
        self.communication += np.count_nonzero(to_send)
        # The 'dewa-m-p' meta-coordinator accounts for this term itself.
        if self.args.case != 'dewa-m-p':
            self.communication += self.args.server_num

    def update(self):
        """Add the rescaled per-round maxima to the weights and reset them."""
        scale = self.args.expert_num / self.args.be
        self.weights = self.weights + scale * self.global_li
        self.global_li = np.zeros(self.args.expert_num)

    def choose(self, server_costs):
        """Draw an expert and charge its max-over-servers cost."""
        gap = self.weights - np.min(self.weights)
        prob = np.exp(-self.args.lr * gap)
        prob /= prob.sum()
        expert_index = np.random.choice(self.args.expert_num, 1, p=prob)
        self.total_cost += server_costs[:, expert_index].max()
        return expert_index

class Coordinator_DEWA_M_P:
    """Meta-learner over K parallel ``Coordinator_DEWA_M`` instances.

    Each sub-coordinator runs a full DEWA-M round against the servers and is
    charged the max-over-servers cost of its pick. An exponential-weights
    draw over those charges selects the reported pick.
    """

    def __init__(self, args):
        self.args = args
        self.communication = 0
        self.total_cost = 0
        self.K = 2 if args.real else 3
        self.coordinators = [Coordinator_DEWA_M(args) for _ in range(self.K)]
        self.weights = np.zeros(self.K)

    def dewa_m_p(self, server_costs, servers):
        """Run one round of every sub-coordinator and return the selected pick.

        Note: ``servers`` is shuffled in place (np.random.shuffle) once per
        sub-coordinator, so the caller's list order changes.
        """
        shifted = self.weights - self.weights.min()
        prob = np.exp(-self.args.lr * shifted)
        prob /= prob.sum()
        coord_index = np.random.choice(self.K, 1, p=prob)

        n_servers = self.args.server_num
        for i, sub in enumerate(self.coordinators):
            sub.communication = 0
            candidate = sub.choose(server_costs)

            # Report order matters: a server only sends when it beats the
            # running maximum already stored in sub.global_li.
            sub.sample()
            np.random.shuffle(servers)
            reports = np.array([server.send(sub) for server in servers])
            sub.receive(reports)
            sub.update()

            picked = server_costs[:, candidate]
            self.weights[i] += picked.max()
            self.communication += n_servers
            self.communication += (picked > 0).sum()

            if coord_index == i:
                expert_index = candidate
                self.total_cost += picked.max()
            self.communication += sub.communication
        self.communication += n_servers

        return expert_index

class Coordinator_EWA_M:
    """Full-information exponential-weights baseline for the max-cost setting.

    Each expert's round loss is its maximum cost across servers.
    """

    def __init__(self, args):
        self.args = args
        self.communication = 0
        self.total_cost = 0
        self.weights = np.zeros(args.expert_num)

    def choose(self, server_costs):
        """Pick an expert, charge its max cost, and update all weights."""
        n_experts, n_servers = self.args.expert_num, self.args.server_num
        gap = self.weights - self.weights.min()
        prob = np.exp(-self.args.lr * gap)
        prob /= prob.sum()
        expert_index = np.random.choice(n_experts, 1, p=prob)
        self.total_cost += server_costs[:, expert_index].max()
        # expert_num values from each server, plus one extra term per server.
        self.communication += n_servers * (n_experts + 1)
        self.weights += server_costs.max(axis=0)
        return expert_index

class Coordinator_DEWA_M_baseline:
    """Baseline DEWA-M coordinator that polls a random subset of servers.

    Each round it queries ``sampled_server_num`` random servers for the
    experts in ``Be``. It uses the max over the polled servers as a
    (rescaled) loss estimate for each sampled expert.
    """

    def __init__(self, args):
        self.args = args
        self.communication = 0
        self.total_cost = 0
        self.weights = np.zeros(args.expert_num)
        self.Be = None
        if args.sparse:
            if args.be == 1:
                sampled_server_num = int(args.server_num * 0.2)
            else:
                sampled_server_num = 1
        elif args.case == 'dewa-m-p-baseline':
            sampled_server_num = int(args.server_num * 0.25)
        else:
            sampled_server_num = int(args.server_num * 0.75)
        # Always poll at least one server. With few servers the fractions
        # above truncate to 0, and ``.max()`` over an empty sample in
        # ``choose`` would raise ValueError.
        self.sampled_server_num = max(1, sampled_server_num)

    def sample(self):
        """Draw ``be`` expert ids into ``self.Be`` (with replacement, numpy default)."""
        self.Be = np.random.choice(self.args.expert_num, self.args.be)

    def choose(self, server_costs):
        """Pick an expert and update the estimates of the sampled experts.

        ``sample()`` must have been called first so that ``self.Be`` is set.

        Returns:
            The chosen expert as a length-1 index array.
        """
        prob = np.exp(-self.args.lr * (self.weights - np.min(self.weights)))
        prob /= prob.sum()
        expert_index = np.random.choice(self.args.expert_num, 1, p=prob)
        self.total_cost += np.max(server_costs[:, expert_index])

        sampled_server = np.random.choice(self.args.server_num, self.sampled_server_num)
        sampled_server_costs = server_costs[sampled_server]

        self.communication += sampled_server.shape[0] * self.Be.shape[0]
        self.communication += sampled_server.shape[0]
        # Duplicate ids in Be are updated once per occurrence.
        for i in self.Be:
            self.weights[i] += sampled_server_costs[:, i].max() * self.args.expert_num / self.args.be
        return expert_index

class Coordinator_DEWA_M_P_baseline:
    """Meta-learner over K parallel ``Coordinator_DEWA_M_baseline`` instances.

    Each sub-coordinator is charged the max-over-servers cost of its pick. An
    exponential-weights draw over those charges selects the reported pick.
    """

    def __init__(self, args):
        self.args = args
        self.communication = 0
        self.total_cost = 0
        self.K = 2 if args.real else 3
        self.coordinators = [Coordinator_DEWA_M_baseline(args) for _ in range(self.K)]
        self.weights = np.zeros(self.K)

    def choose(self, server_costs):
        """Run one round of every sub-coordinator and return the selected pick."""
        logits = -self.args.lr * (self.weights - self.weights.min())
        prob = np.exp(logits)
        prob /= prob.sum()
        coord_index = np.random.choice(self.K, 1, p=prob)

        for i, sub in enumerate(self.coordinators):
            sub.communication = 0
            sub.sample()
            candidate = sub.choose(server_costs)

            self.weights[i] += server_costs[:, candidate].max()
            self.communication += self.args.server_num * 2

            if coord_index == i:
                expert_index = candidate
                self.total_cost += server_costs[:, candidate].max()
            self.communication += sub.communication

        return expert_index

class Coordinator_EXP3_M:
    """EXP3 bandit baseline on the per-round max-over-servers cost.

    Only the chosen expert's weight is updated, using the importance-weighted
    reward ``(1 - max cost) / prob``. ``gamma`` (taken from ``--lr``) is the
    uniform exploration rate.
    """

    def __init__(self, args):
        self.args = args
        self.weights = np.ones(args.expert_num)
        self.communication = 0
        self.total_cost = 0
        self.gamma = args.lr

    def exp3(self, server_costs):
        """Play one EXP3 round and return the chosen expert (length-1 array)."""
        prob = (1 - self.gamma) * self.weights / self.weights.sum() + self.gamma / self.args.expert_num
        new_prob = prob / prob.sum()
        expert_index = np.random.choice(self.args.expert_num, 1, p=new_prob)
        self.total_cost += server_costs[:, expert_index].max()

        x_j = (1 - server_costs[:, expert_index].max()) / prob[expert_index]
        self.weights[expert_index] = self.weights[expert_index] * np.exp(self.gamma * x_j / self.args.expert_num)
        # The sampling distribution depends only on weights / weights.sum(),
        # so rescaling by the max leaves it unchanged. It also stops the
        # ever-growing multiplicative weights from overflowing to inf on long
        # horizons, which would make prob NaN and np.random.choice raise.
        self.weights /= self.weights.max()
        self.communication += self.args.server_num * 2
        return expert_index

def sample_costs_dewa_m(args, best_expert):
    """Sample a ``(server_num, expert_num)`` cost matrix for the DEWA-M setting.

    A per-expert cost is drawn from ``args.dist``; ``best_expert`` gets a
    lower mean. That cost is written to one random server per expert. When
    not sparse, every other entry in the column is drawn uniformly below it,
    so the column's max equals the drawn cost.

    Raises:
        ValueError: if ``args.dist`` is not 'gaussian' or 'bernoulli'.
    """
    if args.dist == 'gaussian':
        cost = np.random.normal(0.6, 1, args.expert_num)
        cost[best_expert] = np.random.normal(0.2, 1)
        cost = np.clip(cost, 0, 1)
    elif args.dist == 'bernoulli':
        cost = np.random.binomial(1, 0.5, args.expert_num)
        cost[best_expert] = np.random.binomial(1, 0.25)
    else:
        # Previously this fell through to an UnboundLocalError on ``cost``.
        raise ValueError("unsupported --dist %r; expected 'gaussian' or 'bernoulli'" % (args.dist,))
    costs = np.zeros((args.server_num, args.expert_num))
    for i in range(args.expert_num):
        if not args.sparse:
            costs[:, i] = np.random.uniform(0, cost[i], args.server_num)
        random_server = np.random.choice(args.server_num, 1)
        costs[random_server, i] = cost[i]
    return costs

class Server_DEWA_M:
    """Server for the DEWA-M (max-cost) protocol.

    ``cost`` holds this server's per-expert costs for the current round and
    must be assigned by the caller before ``send`` is used.
    """

    def __init__(self, args):
        self.args = args
        self.total_cost = 0
        self.current_cost = None
        # Initialised here (as in ``Server``) so an unset cost is reported
        # clearly by ``send`` instead of as an AttributeError.
        self.cost = None

    def send(self, coord):
        """Report costs that exceed the coordinator's running maxima.

        For every expert id in ``coord.Be`` whose cost here is larger than
        ``coord.global_li[expert_id]``, the cost is written into both the
        returned vector and ``coord.global_li`` (mutated in place).

        Returns:
            A length-``expert_num`` vector that is zero except at reported ids.

        Raises:
            RuntimeError: if ``self.cost`` has not been set.
        """
        if self.cost is None:
            raise RuntimeError('Server_DEWA_M.send() called before cost was set')
        sampled_experts = coord.Be
        to_send = np.zeros(self.args.expert_num)
        for expert_id in sampled_experts:
            if self.cost[expert_id] > coord.global_li[expert_id]:
                to_send[expert_id] = self.cost[expert_id]
                coord.global_li[expert_id] = self.cost[expert_id]
        return to_send

class Server_baseline:
    """Simulated server that reveals its full per-expert cost vector each round."""

    def __init__(self, args):
        self.args = args
        self.timestep = 0
        self.current_cost = None
        self.total_cost = 0
        # Latest per-expert cost vector; set by sample() (or by the caller).
        self.cost = None

    def sample(self, best_expert):
        """Draw this round's costs (scaled by 1/server_num) and return them.

        ``best_expert`` gets a lower mean cost, and its cost is accumulated in
        ``total_cost``.

        Raises:
            ValueError: if ``args.dist`` is not 'gaussian' or 'bernoulli'.
        """
        args = self.args
        if args.dist == 'gaussian':
            self.cost = np.random.normal(0.6, 1, args.expert_num)
            self.cost[best_expert] = np.random.normal(0.2, 1)
            self.cost = np.clip(self.cost, 0, 1) / args.server_num
        elif args.dist == 'bernoulli':
            self.cost = np.random.binomial(1, 0.5, args.expert_num)
            self.cost[best_expert] = np.random.binomial(1, 0.25)
            self.cost = self.cost / args.server_num
        else:
            raise ValueError("unsupported --dist %r; expected 'gaussian' or 'bernoulli'" % (args.dist,))
        self.total_cost += self.cost[best_expert]
        return self.cost

class Server:
    """Simulated server that reveals its costs only through sampled bits.

    ``sample`` produces this round's cost vector plus two Bernoulli masks:
    ``alpha`` keeps each expert with probability ``be / expert_num``, and
    ``beta`` fires with probability equal to that expert's cost.
    """

    def __init__(self, args):
        self.args = args
        self.timestep = 0
        self.current_cost = None
        self.total_cost = 0
        self.alpha = None
        self.beta = None
        self.cost = None

    def sample(self, best_expert, real_cost = None):
        """Set this round's costs, draw fresh masks, and return ``(cost, alpha, beta)``.

        In simulation (``args.real`` false), costs are drawn from ``args.dist``
        and scaled by ``1 / server_num``. Otherwise ``real_cost`` is used
        as-is; it must lie in [0, 1] to be a valid Bernoulli probability.

        Raises:
            ValueError: if simulating with an unsupported ``args.dist``.
        """
        args = self.args
        if args.real:
            self.cost = real_cost
        elif args.dist == 'gaussian':
            self.cost = np.random.normal(0.6, 1, args.expert_num)
            self.cost[best_expert] = np.random.normal(0.2, 1)
            self.cost = np.clip(self.cost, 0, 1) / args.server_num
        elif args.dist == 'bernoulli':
            self.cost = np.random.binomial(1, 0.5, args.expert_num)
            self.cost[best_expert] = np.random.binomial(1, 0.25)
            self.cost = self.cost / args.server_num
        else:
            # Previously fell through and silently reused last round's cost.
            raise ValueError("unsupported --dist %r; expected 'gaussian' or 'bernoulli'" % (args.dist,))
        self.total_cost += self.cost[best_expert]
        self.alpha = np.random.binomial(1, p=args.be / args.expert_num, size=args.expert_num)
        self.beta = np.random.binomial(1, p=self.cost, size=args.expert_num)
        return self.cost, self.alpha, self.beta

    def refresh(self):
        """Redraw ``alpha`` and ``beta`` for the current cost vector.

        Raises:
            RuntimeError: if called before ``sample()``.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if self.alpha is None:
            raise RuntimeError('Server.refresh() called before sample()')
        args = self.args
        self.alpha = np.random.binomial(1, p=args.be / args.expert_num, size=args.expert_num)
        self.beta = np.random.binomial(1, p=self.cost, size=args.expert_num)

def compute_regret(servers):
    """Return the best expert's cumulative cost, summed over all servers.

    Despite the name, nothing is subtracted here. The caller forms the regret
    as ``coord.total_cost - compute_regret(servers)``.
    """
    return sum(server.total_cost for server in servers)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--expert_num', type=int, default=100)
    parser.add_argument('--server_num', type=int, default=50)
    parser.add_argument('--T', type=int, default=100000)
    parser.add_argument('--be', type=int, default=10)
    parser.add_argument('--seed', type=int, default=2022)
    parser.add_argument('--lr', type=float, default=1e-1)
    parser.add_argument('--delta', type=float, default=1e-2)
    parser.add_argument('--case', type=str, default='dewa')
    parser.add_argument('--dist', type=str, default='gaussian')
    parser.add_argument("--sparse", default=False, action="store_true")
    parser.add_argument("--real", default=False, action="store_true")

    args = parser.parse_args()

    np.random.seed(args.seed)

    best_expert = args.expert_num // 2

    if args.real:
        real_costs = 1 - np.load('./hpob-cost-data.npy')
        args.T = real_costs.shape[0]
        args.server_num = real_costs.shape[1]
        args.expert_num = real_costs.shape[2]
        real_costs = real_costs / args.server_num

        if '-m' in args.case:
            best_expert = np.argmin(real_costs.max(axis=1).sum(axis=0))
        else:
            best_expert = np.argmin(real_costs.sum(axis=(0, 1)))

    pbar = tqdm.tqdm(total=args.T)

    if args.case == 'dewa-baseline':
        servers = []
        for _ in range(args.server_num):
            servers.append(Server_baseline(args))
        coord = Coordinator_baseline(args)
        if args.sparse:
            file_name = '../results/sparse_dewa_baseline_' + args.dist + '_be_' + str(args.be) + '.txt'
        elif args.real:
            file_name = '../results/real_dewa_baseline_' + args.dist + '_be_' + str(args.be) + '.txt'
        else:
            file_name = '../results/dewa_baseline_' + args.dist + '_be_' + str(args.be) + '.txt'
        with open(file_name, 'a') as txt_file:
            txt_file.write('t, \tregret, \tcommunication\n')
        for t in range(args.T):
            if args.real:
                costs = real_costs[t]
                for server_index in range(args.server_num):
                    servers[server_index].total_cost += costs[server_index, best_expert]
            else:
                costs = []
                for server in servers:
                    costs.append(server.sample(best_expert))
                costs = np.array(costs)

            if args.sparse:
                agg_cost = costs.sum(axis=0)
                sparse_costs = np.zeros((args.server_num, args.expert_num))
                chosen_server = np.random.randint(0, args.server_num, args.expert_num)
                for i in range(args.expert_num):
                    sparse_costs[chosen_server[i], i] = agg_cost[i]
                for idx, server in enumerate(servers):
                    server.cost = sparse_costs[idx]
                costs = copy.deepcopy(sparse_costs)

            coord.dewa(costs)
            pbar.update(1)
            if args.real:
                if t > 0:
                    best_cost = compute_regret(servers)
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
            else:
                if t % 100 == 99:
                    best_cost = compute_regret(servers)
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
    elif args.case == 'dewa-p-baseline':
        servers = []
        for _ in range(args.server_num):
            servers.append(Server_baseline(args))
        coord = Coordinator_DEWAP_baseline(args)
        if args.sparse:
            file_name = '../results/sparse_dewa_p_baseline_' + args.dist + '_be_' + str(args.be) + '.txt'
        elif args.real:
            file_name = '../results/real_dewa_p_baseline_' + args.dist + '_be_' + str(args.be) + '.txt'
        else:
            file_name = '../results/dewa_p_baseline_' + args.dist + '_be_' + str(args.be) + '.txt'
        with open(file_name, 'a') as txt_file:
            txt_file.write('t, \tregret, \tcommunication\n')
        for t in range(args.T):
            if args.real:
                costs = real_costs[t]
                for server_index in range(args.server_num):
                    servers[server_index].total_cost += costs[server_index, best_expert]
            else:
                costs = []
                for server in servers:
                    costs.append(server.sample(best_expert))
                costs = np.array(costs)

            if args.sparse:
                agg_cost = costs.sum(axis=0)
                sparse_costs = np.zeros((args.server_num, args.expert_num))
                chosen_server = np.random.randint(0, args.server_num, args.expert_num)
                for i in range(args.expert_num):
                    sparse_costs[chosen_server[i], i] = agg_cost[i]
                for idx, server in enumerate(servers):
                    server.cost = sparse_costs[idx]
                costs = copy.deepcopy(sparse_costs)

            coord.dewa_p(costs)
            pbar.update(1)
            if args.real:
                if t > 0:
                    best_cost = compute_regret(servers)
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
            else:
                if t % 100 == 99:
                    best_cost = compute_regret(servers)
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
    elif args.case == 'dewa':
        servers = []
        for i in range(args.server_num):
            servers.append(Server(args))
        coord = Coordinator(args)
        if args.sparse:
            file_name = '../results/sparse_dewa_' + args.dist + '_be_' + str(args.be) + '.txt'
        elif args.real:
            file_name = '../results/real_dewa_' + args.dist + '_be_' + str(args.be) + '.txt'
        else:
            file_name = '../results/dewa_' + args.dist + '_be_' + str(args.be) + '.txt'
        with open(file_name, 'a') as txt_file:
            txt_file.write('t, \tregret, \tcommunication\n')
        for t in range(args.T):
            costs = []
            sampled_experts = []
            for server_index in range(args.server_num):
                if args.real:
                    servers[server_index].sample(best_expert, real_costs[t, server_index])
                else:
                    servers[server_index].sample(best_expert)
                costs.append(servers[server_index].cost)
                sampled_experts.append(np.logical_and(servers[server_index].alpha == 1, servers[server_index].beta == 1).astype(float))
            costs = np.array(costs)
            if args.sparse:
                agg_cost = costs.sum(axis=0)
                sparse_costs = np.zeros((args.server_num, args.expert_num))
                sampled_experts = []
                chosen_server = np.random.randint(0, args.server_num, args.expert_num)
                for i in range(args.expert_num):
                    sparse_costs[chosen_server[i], i] = agg_cost[i]
                for idx, server in enumerate(servers):
                    server.cost = sparse_costs[idx]
                    server.refresh()
                    sampled_experts.append(np.logical_and(server.alpha == 1, server.beta == 1).astype(float))
                costs = copy.deepcopy(sparse_costs)
            coord.dewa(sampled_experts, costs)
            pbar.update(1)
            if args.real:
                if t > 0:
                    best_cost = compute_regret(servers)
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
            else:
                if t % 100 == 99:
                    best_cost = compute_regret(servers)
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
    elif args.case == 'dewa-p':
        servers = []
        for _ in range(args.server_num):
            servers.append(Server(args))
        coord = Coordinator_DEWAP(args)
        if args.sparse:
            file_name = '../results/sparse_dewa_p_' + args.dist + '_be_' + str(args.be) + '.txt'
        elif args.real:
            file_name = '../results/real_dewa_p_' + args.dist + '_be_' + str(args.be) + '.txt'
        else:
            file_name = '../results/dewa_p_' + args.dist + '_be_' + str(args.be) + '.txt'
        with open(file_name, 'a') as txt_file:
            txt_file.write('t, \tregret, \tcommunication\n')
        for t in range(args.T):
            costs = []
            for server_index in range(args.server_num):
                if args.real:
                    servers[server_index].sample(best_expert, real_costs[t, server_index])
                else:
                    servers[server_index].sample(best_expert)
                costs.append(servers[server_index].cost)
            costs = np.array(costs)

            if args.sparse:
                agg_cost = costs.sum(axis=0)
                sparse_costs = np.zeros((args.server_num, args.expert_num))
                chosen_server = np.random.randint(0, args.server_num, args.expert_num)
                for i in range(args.expert_num):
                    sparse_costs[chosen_server[i], i] = agg_cost[i]
                for idx, server in enumerate(servers):
                    server.cost = sparse_costs[idx]
                    server.refresh()
                costs = copy.deepcopy(sparse_costs)

            coord.dewa_p(costs, servers)
            pbar.update(1)
            if args.real:
                if t > 0:
                    best_cost = compute_regret(servers)
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
            else:
                if t % 100 == 99:
                    best_cost = compute_regret(servers)
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
    elif args.case == 'ewa':
        servers = []
        for _ in range(args.server_num):
            servers.append(Server_baseline(args))
        coord = Coordinator_EWA(args)
        if args.sparse:
            file_name = '../results/sparse_ewa_' + args.dist + '.txt'
        elif args.real:
            file_name = '../results/real_ewa_be_' + str(args.be) + '.txt'
        else:
            file_name = '../results/ewa_' + args.dist + '.txt'
        with open(file_name, 'a') as txt_file:
            txt_file.write('t, \tregret, \tcommunication\n')
        for t in range(args.T):
            if args.real:
                costs = real_costs[t]
                for server_index in range(args.server_num):
                    servers[server_index].total_cost += costs[server_index, best_expert]
            else:
                costs = []
                for server in servers:
                    costs.append(server.sample(best_expert))
                costs = np.array(costs)

            if args.sparse:
                agg_cost = costs.sum(axis=0)
                sparse_costs = np.zeros((args.server_num, args.expert_num))
                chosen_server = np.random.randint(0, args.server_num, args.expert_num)
                for i in range(args.expert_num):
                    sparse_costs[chosen_server[i], i] = agg_cost[i]
                for idx, server in enumerate(servers):
                    server.cost = sparse_costs[idx]
                costs = copy.deepcopy(sparse_costs)

            coord.ewa(costs)
            pbar.update(1)
            if args.real:
                if t > 0:
                    best_cost = compute_regret(servers)
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
            else:
                if t % 100 == 99:
                    best_cost = compute_regret(servers)
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
    elif args.case == 'exp3':
        servers = []
        for _ in range(args.server_num):
            servers.append(Server_baseline(args))
        coord = Coordinator_EXP3(args)
        if args.sparse:
            file_name = '../results/sparse_exp3_' + args.dist + '.txt'
        elif args.real:
            file_name = '../results/real_exp3_be_' + str(args.be) + '.txt'
        else:
            file_name = '../results/exp3_' + args.dist + '.txt'
        with open(file_name, 'a') as txt_file:
            txt_file.write('t, \tregret, \tcommunication\n')
        for t in range(args.T):
            if args.real:
                costs = real_costs[t]
                for server_index in range(args.server_num):
                    servers[server_index].total_cost += costs[server_index, best_expert]
            else:
                costs = []
                for server in servers:
                    costs.append(server.sample(best_expert))
                costs = np.array(costs)

            if args.sparse:
                agg_cost = costs.sum(axis=0)
                sparse_costs = np.zeros((args.server_num, args.expert_num))
                chosen_server = np.random.randint(0, args.server_num, args.expert_num)
                for i in range(args.expert_num):
                    sparse_costs[chosen_server[i], i] = agg_cost[i]
                for idx, server in enumerate(servers):
                    server.cost = sparse_costs[idx]
                costs = copy.deepcopy(sparse_costs)
            coord.exp3(costs)
            pbar.update(1)
            if args.real:
                if t > 0:
                    best_cost = compute_regret(servers)
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
            else:
                if t % 100 == 99:
                    best_cost = compute_regret(servers)
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
    elif args.case == 'dewa-m':
        servers = []
        for _ in range(args.server_num):
            servers.append(Server_DEWA_M(args))
        coord = Coordinator_DEWA_M(args)
        if args.sparse:
            file_name = '../results/sparse_dewa_m_' + args.dist + '_be_' + str(args.be) + '.txt'
        elif args.real:
            file_name = '../results/real_dewa_m_' + str(args.be) + '.txt'
        else:
            file_name = '../results/dewa_m_' + args.dist + '_be_' + str(args.be) + '.txt'
        with open(file_name, 'a') as txt_file:
            txt_file.write('t, \tregret, \tcommunication\n')
        best_cost = 0
        for t in range(args.T):
            if args.real:
                costs = real_costs[t]
            else:
                costs = sample_costs_dewa_m(args, best_expert)
            for i, server in enumerate(servers):
                server.cost = costs[i]
            best_cost += costs[:, best_expert].max()

            coord.choose(costs)
            coord.sample()
            sent = []
            np.random.shuffle(servers)
            for server in servers:
                sent.append(server.send(coord))
            sent = np.array(sent)
            coord.receive(sent)
            coord.update()

            pbar.update(1)
            if args.real:
                if t > 0:
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
            else:
                if t % 100 == 99:
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
    elif args.case == 'dewa-m-p':
        servers = []
        for _ in range(args.server_num):
            servers.append(Server_DEWA_M(args))
        coord = Coordinator_DEWA_M_P(args)
        if args.sparse:
            file_name = '../results/sparse_dewa_m_p_' + args.dist + '_be_' + str(args.be) + '.txt'
        elif args.real:
            file_name = '../results/real_dewa_m_p_' + str(args.be) + '.txt'
        else:
            file_name = '../results/dewa_m_p_' + args.dist + '_be_' + str(args.be) + '.txt'
        with open(file_name, 'a') as txt_file:
            txt_file.write('t, \tregret, \tcommunication\n')
        best_cost = 0
        for t in range(args.T):
            if args.real:
                costs = real_costs[t]
            else:
                costs = sample_costs_dewa_m(args, best_expert)
            for i, server in enumerate(servers):
                server.cost = costs[i]
            best_cost += costs[:, best_expert].max()

            coord.dewa_m_p(costs, servers)

            pbar.update(1)
            if args.real:
                if t > 0:
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
            else:
                if t % 100 == 99:
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
    elif args.case == 'ewa-m':
        servers = []
        for _ in range(args.server_num):
            servers.append(Server_baseline(args))
        coord = Coordinator_EWA_M(args)
        if args.sparse:
            file_name = '../results/sparse_ewa_m_' + args.dist + '_be_' + str(args.be) + '.txt'
        elif args.real:
            file_name = '../results/real_ewa_m_' + str(args.be) + '.txt'
        else:
            file_name = '../results/ewa_m_' + args.dist + '_be_' + str(args.be) + '.txt'
        with open(file_name, 'a') as txt_file:
            txt_file.write('t, \tregret, \tcommunication\n')
        best_cost = 0
        for t in range(args.T):
            if args.real:
                costs = real_costs[t]
            else:
                costs = sample_costs_dewa_m(args, best_expert)
            for i, server in enumerate(servers):
                server.cost = costs[i]
            best_cost += costs[:, best_expert].max()
            coord.choose(costs)
            pbar.update(1)
            if args.real:
                if t > 0:
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
            else:
                if t % 100 == 99:
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
    elif args.case == 'exp3-m':
        servers = []
        for _ in range(args.server_num):
            servers.append(Server_baseline(args))
        coord = Coordinator_EXP3_M(args)
        if args.sparse:
            file_name = '../results/sparse_exp3_m_' + args.dist + '_be_' + str(args.be) + '.txt'
        elif args.real:
            file_name = '../results/real_exp3_m_' + str(args.be) + '.txt'
        else:
            file_name = '../results/exp3_m_' + args.dist + '_be_' + str(args.be) + '.txt'
        with open(file_name, 'a') as txt_file:
            txt_file.write('t, \tregret, \tcommunication\n')
        best_cost = 0
        for t in range(args.T):
            if args.real:
                costs = real_costs[t]
            else:
                costs = sample_costs_dewa_m(args, best_expert)
            for i, server in enumerate(servers):
                server.cost = costs[i]
            best_cost += costs[:, best_expert].max()
            coord.exp3(costs)
            pbar.update(1)
            if args.real:
                if t > 0:
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
            else:
                if t % 100 == 99:
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
    elif args.case == 'dewa-m-baseline':
        servers = []
        for _ in range(args.server_num):
            servers.append(Server_baseline(args))
        coord = Coordinator_DEWA_M_baseline(args)
        if args.sparse:
            file_name = '../results/sparse_dewa_m_baseline_' + args.dist + '_be_' + str(args.be) + '.txt'
        elif args.real:
            file_name = '../results/real_dewa_m_baseline_' + str(args.be) + '.txt'
        else:
            file_name = '../results/dewa_m_baseline_' + args.dist + '_be_' + str(args.be) + '.txt'
        with open(file_name, 'a') as txt_file:
            txt_file.write('t, \tregret, \tcommunication\n')
        best_cost = 0
        for t in range(args.T):
            if args.real:
                costs = real_costs[t]
            else:
                costs = sample_costs_dewa_m(args, best_expert)
            for i, server in enumerate(servers):
                server.cost = costs[i]
            best_cost += costs[:, best_expert].max()
            coord.sample()
            coord.choose(costs)
            pbar.update(1)
            if args.real:
                if t > 0:
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
            else:
                if t % 100 == 99:
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
    elif args.case == 'dewa-m-p-baseline':
        # DEWA-M-P baseline: each round the coordinator chooses with the full cost matrix.
        servers = []
        for _ in range(args.server_num):
            servers.append(Server_baseline(args))
        coord = Coordinator_DEWA_M_P_baseline(args)
        # Output file depends on the experiment mode; sparse takes precedence over real.
        if args.sparse:
            file_name = '../results/sparse_dewa_m_p_baseline_' + args.dist + '_be_' + str(args.be) + '.txt'
        elif args.real:
            file_name = '../results/real_dewa_m_p_baseline_' + str(args.be) + '.txt'
        else:
            file_name = '../results/dewa_m_p_baseline_' + args.dist + '_be_' + str(args.be) + '.txt'
        # Append mode: every run adds its own header line to the same file.
        with open(file_name, 'a') as txt_file:
            txt_file.write('t, \tregret, \tcommunication\n')
        # Running benchmark: sum over rounds of best_expert's max-over-servers cost.
        best_cost = 0
        for t in range(args.T):
            # Replay recorded costs for real data, otherwise sample a fresh cost matrix.
            if args.real:
                costs = real_costs[t]
            else:
                costs = sample_costs_dewa_m(args, best_expert)
            # Row i of the cost matrix is handed to server i.
            for i, server in enumerate(servers):
                server.cost = costs[i]
            best_cost += costs[:, best_expert].max()
            coord.choose(costs)
            pbar.update(1)
            # Real data is logged every round after the first (t > 0 avoids dividing by zero);
            # simulated data every 100 rounds.
            # NOTE(review): t is 0-based, so t+1 rounds have elapsed here; the averages divide by
            # t, matching the convention used by the other cases.
            if args.real:
                if t > 0:
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))
            else:
                if t % 100 == 99:
                    pbar.set_postfix({'AvgCost': coord.total_cost / t, 'AvgBestExpert': best_cost / t, 'Commu': coord.communication})
                    with open(file_name, 'a') as txt_file:
                        txt_file.write('%d, \t%f, \t%d\n'%(t, (coord.total_cost - best_cost)/t, coord.communication))